{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2020 The Google Research Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Activation Clustering Model: Similar Examples, Concepts\n",
    "\n",
    "This notebook shows how to use an activation clustering model to find training examples that are similar to a given test example, and to explore \"concepts\" (each cluster of the model is treated as a concept).\n",
    "\n",
    "Here we use a trained activation clustering model (in the `work_dir` directory) whose baseline model is a ResNet classification model trained on the CIFAR-10 dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "from activation_clustering import ac_model, utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Restore an activation clustering model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acm = ac_model.ACModel.restore('work_dir')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from activation_clustering import utils\n",
    "utils.get_activation_shapes(acm.baseline_model, acm.activation_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activation names stored on the restored model; the same list is passed to\n",
    "# `utils.get_activation_shapes` together with the baseline model.\n",
    "acm.activation_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same dataset preprocessing as used in the baseline model training.\n",
    "def input_fn(batch_size, ds, label_key='label'):\n",
    "    dataset = ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    def interface(batch):\n",
    "        features = tf.cast(batch['image'], tf.float32) / 255     \n",
    "        labels = batch[label_key]\n",
    "\n",
    "        return features, labels\n",
    "\n",
    "    return dataset.map(interface)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ds = tfds.load(\n",
    "    'cifar10:3.*.*',\n",
    "    shuffle_files=False,\n",
    "    split='test'\n",
    ")\n",
    "\n",
    "test_ds = input_fn(batch_size=10000, ds=test_ds)\n",
    "\n",
    "test_features, test_labels =list(test_ds.take(1))[0]\n",
    "del test_ds\n",
    "\n",
    "test_features = test_features.numpy()\n",
    "test_labels = test_labels.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# An activation clustering model can be used as a surrogate model for its baseline model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the surrogate model\n",
    "print('surrogate model accuracy: ', acm.evaluate(features=test_features, y=test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fidelity: how much does the surrogate model agree with the baseline model\n",
    "baseline_labels = np.argmax(acm.baseline_model.predict(test_features), axis=-1)\n",
    "print('fidelity: ', acm.evaluate(features=test_features, y=baseline_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surrogate-model output for the first three test images\n",
    "# (presumably per-class probabilities -- verify against `ACModel.predict_proba`).\n",
    "acm.predict_proba(features=test_features[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clustering-based predicted labels for the first three test images\n",
    "# (output format is not shown here -- see `ACModel.clustering_predict_labels`).\n",
    "acm.clustering_predict_labels(features=test_features[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# accuracy of the baseline model\n",
    "np.sum(test_labels == baseline_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get similar training examples of random test examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the training features (images) for visualization\n",
    "train_ds = tfds.load(\n",
    "    'cifar10:3.*.*',\n",
    "    shuffle_files=False,\n",
    "    split='train'\n",
    ")\n",
    "\n",
    "train_ds = input_fn(batch_size=50000, ds=train_ds)\n",
    "\n",
    "train_features, train_labels =list(train_ds.take(1))[0]\n",
    "del train_ds\n",
    "\n",
    "train_features = train_features.numpy()\n",
    "train_labels = train_labels.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_indices = np.random.choice(10000, size=10, replace=False)\n",
    "\n",
    "test_feat = test_features[test_indices]\n",
    "print(test_feat.shape)\n",
    "\n",
    "print('test indices:    {}'.format(test_indices))\n",
    "print(test_labels[test_indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight presets passed to `acm.query(..., weights=...)` in the following cells.\n",
    "# Four entries each, presumably one per name in `acm.activation_names` in the\n",
    "# same order -- TODO confirm against `ACModel.query`.\n",
    "# `equal`: same weight everywhere; `low`: weight only on the first two entries;\n",
    "# `high`: weight only on the last two entries.\n",
    "equal = [1.0, 1.0, 1.0, 1.0]\n",
    "low = [2.0, 1.0, 0.0, 0.0]\n",
    "high = [0.0, 0.0, 1.0, 2.0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distances between embeddings from different layers are averaged to output visually similar images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = acm.query(features=test_feat, weights=equal)\n",
    "ind\n",
    "\n",
    "train_image_arrays_list = train_features[ind]\n",
    "\n",
    "utils.visualize_similar(\n",
    "    test_image_arrays=test_feat,\n",
    "    train_image_arrays_list=train_image_arrays_list,\n",
    "    test_labels=test_labels[test_indices].tolist(),\n",
    "    train_labels=train_labels[ind].tolist()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = acm.query(features=test_feat, weights=low)\n",
    "ind\n",
    "\n",
    "train_image_arrays_list = train_features[ind]\n",
    "\n",
    "utils.visualize_similar(\n",
    "    test_image_arrays=test_feat,\n",
    "    train_image_arrays_list=train_image_arrays_list,\n",
    "    test_labels=test_labels[test_indices].tolist(),\n",
    "    train_labels=train_labels[ind].tolist()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = acm.query(features=test_feat, weights=high)\n",
    "ind\n",
    "\n",
    "train_image_arrays_list = train_features[ind]\n",
    "\n",
    "utils.visualize_similar(\n",
    "    test_image_arrays=test_feat,\n",
    "    train_image_arrays_list=train_image_arrays_list,\n",
    "    test_labels=test_labels[test_indices].tolist(),\n",
    "    train_labels=train_labels[ind].tolist()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compare with similar images based on the last activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the activations for the last entry of `acm.activation_names` only, for all\n",
    "# training and test images. Both results are dicts (their keys are printed below);\n",
    "# the `query` helper defined in the next cell looks them up by activation name.\n",
    "\n",
    "activation_names = [acm.activation_names[-1]]\n",
    "\n",
    "train_activations = acm.get_activations_from_features(train_features, activation_names)\n",
    "test_activations = acm.get_activations_from_features(test_features, activation_names)\n",
    "\n",
    "print(train_activations.keys(), test_activations.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# query with activations\n",
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "\n",
    "def query(test_acts, train_acts, weights=None):\n",
    "    if weights is None:\n",
    "        weights = [1.0] * len(acm.activation_names)\n",
    "    distances = 0.0\n",
    "    for i, activation_name in enumerate(acm.activation_names):\n",
    "        if activation_name not in test_acts:\n",
    "            continue\n",
    "\n",
    "        test_act = test_acts[activation_name]\n",
    "        train_act = train_acts[activation_name]\n",
    "        \n",
    "        # flatten\n",
    "        test_act = test_act.reshape((len(test_act), -1))\n",
    "        train_act = train_act.reshape((len(train_act), -1))\n",
    "\n",
    "        dis = cdist(test_act, train_act, 'euclidean')\n",
    "        distances += dis * weights[i]\n",
    "\n",
    "    ind = np.argsort(distances, axis=-1)\n",
    "    \n",
    "    return ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_acts = {k:v[test_indices] for k, v in test_activations.items()}\n",
    "train_acts = train_activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using only the last activation: `test_acts` / `train_acts` were computed for\n",
    "# `acm.activation_names[-1]` alone, so `query` skips every other activation.\n",
    "%time last_act_ind = query(test_acts, train_acts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the distance between the last activations to determine which images are similar does not capture low-level visual features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image_arrays_list = train_features[last_act_ind]\n",
    "\n",
    "utils.visualize_similar(\n",
    "    test_image_arrays=test_feat,\n",
    "    train_image_arrays_list=train_image_arrays_list,\n",
    "    test_labels=test_labels[test_indices].tolist(),\n",
    "    train_labels=train_labels[last_act_ind].tolist()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concepts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we think of each cluster as a \"concept\" and visualize images closest to each cluster centroid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Used below as `concept_indices[activation_index]` to index `train_features`;\n",
    "# presumably the training examples closest to each cluster centroid -- TODO confirm.\n",
    "concept_indices = acm.concept_indices()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each activation has its own list of clusters.  Earlier activations capture low-level visual features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation_index = 0\n",
    "print('Concepts based on {}'.format(acm.activation_names[activation_index]))\n",
    "concept_ind = concept_indices[activation_index]\n",
    "train_image_arrays_list = train_features[concept_ind]\n",
    "\n",
    "utils.visualize_concepts(train_image_arrays_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation_index = 2\n",
    "print('Concepts based on {}'.format(acm.activation_names[activation_index]))\n",
    "concept_ind = concept_indices[activation_index]\n",
    "train_image_arrays_list = train_features[concept_ind]\n",
    "\n",
    "utils.visualize_concepts(train_image_arrays_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation_index = 3\n",
    "print('Concepts based on {}'.format(acm.activation_names[activation_index]))\n",
    "concept_ind = concept_indices[activation_index]\n",
    "train_image_arrays_list = train_features[concept_ind]\n",
    "\n",
    "utils.visualize_concepts(train_image_arrays_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
